
<!DOCTYPE html>

<html lang="zh">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

    <title>ViT解读 &#8212; 深入浅出PyTorch</title>
    
  <!-- Loaded before other Sphinx assets -->
  <link href="../_static/styles/theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
<link href="../_static/styles/pydata-sphinx-theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">

    
  <link rel="stylesheet"
    href="../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">

    <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
    <link rel="stylesheet" href="../_static/styles/sphinx-book-theme.css?digest=62ba249389abaaa9ffc34bf36a076bdc1d65ee18" type="text/css" />
    <link rel="stylesheet" type="text/css" href="../_static/togglebutton.css" />
    <link rel="stylesheet" type="text/css" href="../_static/mystnb.css" />
    <link rel="stylesheet" type="text/css" href="../_static/plot_directive.css" />
    
  <!-- Pre-loaded scripts that we'll load fully later -->
  <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf">

    <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
    <script src="../_static/jquery.js"></script>
    <script src="../_static/underscore.js"></script>
    <script src="../_static/doctools.js"></script>
    <script>let toggleHintShow = 'Click to show';</script>
    <script>let toggleHintHide = 'Click to hide';</script>
    <script>let toggleOpenOnPrint = 'true';</script>
    <script src="../_static/togglebutton.js"></script>
    <script src="../_static/scripts/sphinx-book-theme.js?digest=f31d14ad54b65d19161ba51d4ffff3a77ae00456"></script>
    <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
    <script>window.MathJax = {"options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
    <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
    <link rel="index" title="索引" href="../genindex.html" />
    <link rel="search" title="搜索" href="../search.html" />
    <link rel="next" title="Swin Transformer解读" href="Swin-Transformer%E8%A7%A3%E8%AF%BB.html" />
    <link rel="prev" title="Transformer 解读" href="Transformer%20%E8%A7%A3%E8%AF%BB.html" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta name="docsearch:language" content="zh">
    

    <!-- Google Analytics -->
    
  </head>
  <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="60">
<!-- Checkboxes to toggle the left sidebar -->
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation" aria-label="Toggle navigation sidebar">
<label class="overlay overlay-navbar" for="__navigation">
    <div class="visually-hidden">Toggle navigation sidebar</div>
</label>
<!-- Checkboxes to toggle the in-page toc -->
<input type="checkbox" class="sidebar-toggle" name="__page-toc" id="__page-toc" aria-label="Toggle in-page Table of Contents">
<label class="overlay overlay-pagetoc" for="__page-toc">
    <div class="visually-hidden">Toggle in-page Table of Contents</div>
</label>
<!-- Headers at the top -->
<div class="announcement header-item noprint"></div>
<div class="header header-item noprint"></div>

    
    <div class="container-fluid" id="banner"></div>

    

    <div class="container-xl">
      <div class="row">
          
<!-- Sidebar -->
<div class="bd-sidebar noprint" id="site-navigation">
    <div class="bd-sidebar__content">
        <div class="bd-sidebar__top"><div class="navbar-brand-box">
    <a class="navbar-brand text-wrap" href="../index.html">
      
      
      
      <h1 class="site-logo" id="site-title">深入浅出PyTorch</h1>
      
    </a>
</div><form class="bd-search d-flex align-items-center" action="../search.html" method="get">
  <i class="icon fas fa-search"></i>
  <input type="search" class="form-control" name="q" id="search-input" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off" >
</form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
    <div class="bd-toc-item active">
        <p aria-level="2" class="caption" role="heading">
 <span class="caption-text">
  目录
 </span>
</p>
<ul class="current nav bd-sidenav">
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/index.html">
   第零章：前置知识
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  <label for="toctree-checkbox-1">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.1%20%E4%BA%BA%E5%B7%A5%E6%99%BA%E8%83%BD%E7%AE%80%E5%8F%B2.html">
     人工智能简史
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.2%20%E8%AF%84%E4%BB%B7%E6%8C%87%E6%A0%87.html">
     模型评价指标
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.3%20%E5%B8%B8%E7%94%A8%E5%8C%85%E7%9A%84%E5%AD%A6%E4%B9%A0.html">
     常用包的学习
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.4%20Jupyter%E7%9B%B8%E5%85%B3%E6%93%8D%E4%BD%9C.html">
     Jupyter notebook/Lab 简述
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/index.html">
   第一章：PyTorch的简介和安装
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  <label for="toctree-checkbox-2">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.1%20PyTorch%E7%AE%80%E4%BB%8B.html">
     1.1 PyTorch简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.2%20PyTorch%E7%9A%84%E5%AE%89%E8%A3%85.html">
     1.2 PyTorch的安装
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.3%20PyTorch%E7%9B%B8%E5%85%B3%E8%B5%84%E6%BA%90.html">
     1.3 PyTorch相关资源
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/index.html">
   第二章：PyTorch基础知识
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  <label for="toctree-checkbox-3">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.1%20%E5%BC%A0%E9%87%8F.html">
     2.1 张量
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.2%20%E8%87%AA%E5%8A%A8%E6%B1%82%E5%AF%BC.html">
     2.2 自动求导
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.3%20%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97%E7%AE%80%E4%BB%8B.html">
     2.3 并行计算简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.4%20AI%E7%A1%AC%E4%BB%B6%E5%8A%A0%E9%80%9F%E8%AE%BE%E5%A4%87.html">
     AI硬件加速设备
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/index.html">
   第三章：PyTorch的主要组成模块
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  <label for="toctree-checkbox-4">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.1%20%E6%80%9D%E8%80%83%EF%BC%9A%E5%AE%8C%E6%88%90%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%9A%84%E5%BF%85%E8%A6%81%E9%83%A8%E5%88%86.html">
     3.1 思考：完成深度学习的必要部分
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.2%20%E5%9F%BA%E6%9C%AC%E9%85%8D%E7%BD%AE.html">
     3.2 基本配置
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.3%20%E6%95%B0%E6%8D%AE%E8%AF%BB%E5%85%A5.html">
     3.3 数据读入
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.4%20%E6%A8%A1%E5%9E%8B%E6%9E%84%E5%BB%BA.html">
     3.4 模型构建
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.5%20%E6%A8%A1%E5%9E%8B%E5%88%9D%E5%A7%8B%E5%8C%96.html">
     3.5 模型初始化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.6%20%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
     3.6 损失函数
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.7%20%E8%AE%AD%E7%BB%83%E4%B8%8E%E8%AF%84%E4%BC%B0.html">
     3.7 训练和评估
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.8%20%E5%8F%AF%E8%A7%86%E5%8C%96.html">
     3.8 可视化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.9%20%E4%BC%98%E5%8C%96%E5%99%A8.html">
     3.9 PyTorch优化器
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/index.html">
   第四章：PyTorch基础实战
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  <label for="toctree-checkbox-5">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/4.1%20ResNet.html">
     4.1 ResNet
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/4.4%20FashionMNIST%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB.html">
     基础实战——FashionMNIST时装分类
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/index.html">
   第五章：PyTorch模型定义
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
  <label for="toctree-checkbox-6">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.1%20PyTorch%E6%A8%A1%E5%9E%8B%E5%AE%9A%E4%B9%89%E7%9A%84%E6%96%B9%E5%BC%8F.html">
     5.1 PyTorch模型定义的方式
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.2%20%E5%88%A9%E7%94%A8%E6%A8%A1%E5%9E%8B%E5%9D%97%E5%BF%AB%E9%80%9F%E6%90%AD%E5%BB%BA%E5%A4%8D%E6%9D%82%E7%BD%91%E7%BB%9C.html">
     5.2 利用模型块快速搭建复杂网络
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.3%20PyTorch%E4%BF%AE%E6%94%B9%E6%A8%A1%E5%9E%8B.html">
     5.3 PyTorch修改模型
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.4%20PyTorh%E6%A8%A1%E5%9E%8B%E4%BF%9D%E5%AD%98%E4%B8%8E%E8%AF%BB%E5%8F%96.html">
     5.4 PyTorch模型保存与读取
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/index.html">
   第六章：PyTorch进阶训练技巧
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/>
  <label for="toctree-checkbox-7">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.1%20%E8%87%AA%E5%AE%9A%E4%B9%89%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
     6.1 自定义损失函数
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.2%20%E5%8A%A8%E6%80%81%E8%B0%83%E6%95%B4%E5%AD%A6%E4%B9%A0%E7%8E%87.html">
     6.2 动态调整学习率
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-torchvision.html">
     6.3 模型微调-torchvision
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-timm.html">
     6.3 模型微调 - timm
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.4%20%E5%8D%8A%E7%B2%BE%E5%BA%A6%E8%AE%AD%E7%BB%83.html">
     6.4 半精度训练
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.5%20%E6%95%B0%E6%8D%AE%E5%A2%9E%E5%BC%BA-imgaug.html">
     6.5 数据增强-imgaug
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.6%20%E4%BD%BF%E7%94%A8argparse%E8%BF%9B%E8%A1%8C%E8%B0%83%E5%8F%82.html">
     6.6 使用argparse进行调参
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/index.html">
   第七章：PyTorch可视化
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/>
  <label for="toctree-checkbox-8">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.1%20%E5%8F%AF%E8%A7%86%E5%8C%96%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84.html">
     7.1 可视化网络结构
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.2%20CNN%E5%8D%B7%E7%A7%AF%E5%B1%82%E5%8F%AF%E8%A7%86%E5%8C%96.html">
     7.2 CNN可视化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.3%20%E4%BD%BF%E7%94%A8TensorBoard%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.html">
     7.3 使用TensorBoard可视化训练过程
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.4%20%E4%BD%BF%E7%94%A8wandb%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.html">
     7.4 使用wandb可视化训练过程
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/index.html">
   第八章：PyTorch生态简介
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/>
  <label for="toctree-checkbox-9">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.1%20%E6%9C%AC%E7%AB%A0%E7%AE%80%E4%BB%8B.html">
     8.1 本章简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.2%20%E5%9B%BE%E5%83%8F%20-%20torchvision.html">
     8.2 torchvision
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.3%20%E8%A7%86%E9%A2%91%20-%20PyTorchVideo.html">
     8.3 PyTorchVideo简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.4%20%E6%96%87%E6%9C%AC%20-%20torchtext.html">
     8.4 torchtext简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.5%20%E9%9F%B3%E9%A2%91%20-%20torchaudio.html">
     8.5 torchaudio简介
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B9%9D%E7%AB%A0/index.html">
   第九章：PyTorch的模型部署
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/>
  <label for="toctree-checkbox-10">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B9%9D%E7%AB%A0/9.1%20%E4%BD%BF%E7%94%A8ONNX%E8%BF%9B%E8%A1%8C%E9%83%A8%E7%BD%B2%E5%B9%B6%E6%8E%A8%E7%90%86.html">
     9.1 使用ONNX进行部署并推理
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 current active has-children">
  <a class="reference internal" href="index.html">
   第十章：常见代码解读
  </a>
  <input checked="" class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/>
  <label for="toctree-checkbox-11">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul class="current">
   <li class="toctree-l2">
    <a class="reference internal" href="10.1%20%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB.html">
     10.1 图像分类简介（补充中）
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="10.2%20%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B.html">
     目标检测简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="10.3%20%E5%9B%BE%E5%83%8F%E5%88%86%E5%89%B2.html">
     10.3 图像分割简介（补充中）
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="ResNet%E6%BA%90%E7%A0%81%E8%A7%A3%E8%AF%BB.html">
     ResNet源码解读
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="RNN%E8%AF%A6%E8%A7%A3%E5%8F%8A%E5%85%B6%E5%AE%9E%E7%8E%B0.html">
     文章结构
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="LSTM%E8%A7%A3%E8%AF%BB%E5%8F%8A%E5%AE%9E%E6%88%98.html">
     文章结构
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="Transformer%20%E8%A7%A3%E8%AF%BB.html">
     Transformer 解读
    </a>
   </li>
   <li class="toctree-l2 current active">
    <a class="current reference internal" href="#">
     ViT解读
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="Swin-Transformer%E8%A7%A3%E8%AF%BB.html">
     Swin Transformer解读
    </a>
   </li>
  </ul>
 </li>
</ul>

    </div>
</nav></div>
        <div class="bd-sidebar__bottom">
             <!-- To handle the deprecated key -->
            
            <div class="navbar_extra_footer">
            Theme by the <a href="https://ebp.jupyterbook.org">Executable Book Project</a>
            </div>
            
        </div>
    </div>
    <div id="rtd-footer-container"></div>
</div>


          


          
<!-- A tiny helper pixel to detect if we've scrolled -->
<div class="sbt-scroll-pixel-helper"></div>
<!-- Main content -->
<div class="col py-0 content-container">
    
    <div class="header-article row sticky-top noprint">
        



<div class="col py-1 d-flex header-article-main">
    <div class="header-article__left">
        
        <label for="__navigation"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="right"
title="Toggle navigation"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-bars"></i>
  </span>

</label>

        
    </div>
    <div class="header-article__right">
<button onclick="toggleFullScreen()"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="bottom"
title="Fullscreen mode"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-expand"></i>
  </span>

</button>

<div class="menu-dropdown menu-dropdown-repository-buttons">
  <button class="headerbtn menu-dropdown__trigger"
      aria-label="Source repositories">
      <i class="fab fa-github"></i>
  </button>
  <div class="menu-dropdown__content">
    <ul>
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Source repository"
>
  

<span class="headerbtn__icon-container">
  <i class="fab fa-github"></i>
  </span>
<span class="headerbtn__text-container">repository</span>
</a>

      </li>
      
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch/issues/new?title=Issue%20on%20page%20%2F第十章/ViT解读.html&body=Your%20issue%20content%20here."
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Open an issue"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-lightbulb"></i>
  </span>
<span class="headerbtn__text-container">open issue</span>
</a>

      </li>
      
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch/edit/master/第十章/ViT解读.md"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Edit this page"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-pencil-alt"></i>
  </span>
<span class="headerbtn__text-container">suggest edit</span>
</a>

      </li>
      
    </ul>
  </div>
</div>

<div class="menu-dropdown menu-dropdown-download-buttons">
  <button class="headerbtn menu-dropdown__trigger"
      aria-label="Download this page">
      <i class="fas fa-download"></i>
  </button>
  <div class="menu-dropdown__content">
    <ul>
      <li>
        <a href="../_sources/第十章/ViT解读.md.txt"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Download source file"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-file"></i>
  </span>
<span class="headerbtn__text-container">.md</span>
</a>

      </li>
      
      <li>
        
<button onclick="printPdf(this)"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="left"
title="Print to PDF"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-file-pdf"></i>
  </span>
<span class="headerbtn__text-container">.pdf</span>
</button>

      </li>
      
    </ul>
  </div>
</div>
<label for="__page-toc"
  class="headerbtn headerbtn-page-toc"
  
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-list"></i>
  </span>

</label>

    </div>
</div>

<!-- Table of contents -->
<div class="col-md-3 bd-toc show noprint">
    <div class="tocsection onthispage pt-5 pb-3">
        <i class="fas fa-list"></i> Contents
    </div>
    <nav id="bd-toc-nav" aria-label="Page">
        <ul class="visible nav section-nav flex-column">
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id1">
   前言
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id2">
   ViT 的整体流程
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#patch-embedding-linear-projection">
   切分和映射 Patch Embedding + Linear Projection
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#class-token-postional-embedding">
   分类表征和位置信息 Class Token + Positional Embedding
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#transformer-encoder">
   Transformer Encoder
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#multi-head-attention">
     Multi-head Attention
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#mlp">
     MLP
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#layer-norm">
     Layer Norm
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id3">
     Transformer Encoder 完整代码
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id4">
   ViT 完整代码
  </a>
 </li>
</ul>

    </nav>
</div>
    </div>
    <div class="article row">
        <div class="col pl-md-3 pl-lg-5 content-container">
            <!-- Table of contents that is only displayed when printing the page -->
            <div id="jb-print-docs-body" class="onlyprint">
                <h1>ViT解读</h1>
                <!-- Table of contents -->
                <div id="print-main-content">
                    <div id="jb-print-toc">
                        
                        <div>
                            <h2> Contents </h2>
                        </div>
                        <nav aria-label="Page">
                            <ul class="visible nav section-nav flex-column">
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id1">
   前言
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id2">
   ViT 的整体流程
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#patch-embedding-linear-projection">
   切分和映射 Patch Embedding + Linear Projection
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#class-token-postional-embedding">
   分类表征和位置信息 Class Token + Positional Embedding
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#transformer-encoder">
   Transformer Encoder
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#multi-head-attention">
     Multi-head Attention
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#mlp">
     MLP
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#layer-norm">
     Layer Norm
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id3">
     Transformer Encoder 完整代码
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id4">
   ViT 完整代码
  </a>
 </li>
</ul>

                        </nav>
                    </div>
                </div>
            </div>
            <main id="main-content" role="main">
                
              <div>
                
  <section class="tex2jax_ignore mathjax_ignore" id="vit">
<h1>ViT解读<a class="headerlink" href="#vit" title="永久链接至标题">#</a></h1>
<p>
<font size=3><b>[ViT] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.</b></font>
<br>
<font size=2>Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.</font>
<br>
<font size=2>ICLR 2021.</font>
<a href='https://arxiv.org/pdf/2010.11929v2.pdf'>[paper]</a> <a href='https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py'>[code]</a> 
<br>
<font size=3>解读者：牛志康，西安电子科技大学本科生，Datawhale成员；小饭同学，香港城市大学研究生
</font>
<br>
</p>
<section id="id1">
<h2>前言<a class="headerlink" href="#id1" title="永久链接至标题">#</a></h2>
<p>Transformer 已经成为自然语言处理任务的一种基础网络，但它在计算机视觉中的应用仍然有限。因为 Transformer 对序列进行建模，如果我们将图像中的每一个像素都作为序列中的元素，序列的长度将与图片的大小呈平方关系，从而导致计算量大大增加。现有的工作要么是将注意力与卷积网络结合使用，要么用注意力机制替换 CNN 的某些组件或者降低图片的序列长度。这些改进都是基于 convolutional neural network (CNN) 卷积神经网络构建的，于是人们就在希望有一种完全基于 Transformer 的骨干网络，可以拥有 Transformer 全局建模的特性也可以不过多修改原始 Transformer 的结构。基于这种 motivation，才出现了 Vision Transformer (ViT) 这篇优秀的工作。</p>
<p>本文将从原理和代码实现上进行讲解，结合本课程需求，我们将着重讲解代码的实现，论文中更多的细节还请各位同学详细阅读原论文或关注 Whalepaper 后续的论文精读。</p>
</section>
<section id="id2">
<h2>ViT 的整体流程<a class="headerlink" href="#id2" title="永久链接至标题">#</a></h2>
<p>如下图所示，ViT 的主要思想是将图片分成一个一个的小 <code class="docutils literal notranslate"><span class="pre">patch</span></code>，将每一个 <code class="docutils literal notranslate"><span class="pre">patch</span></code> 作为序列的元素输入 Transformer 中进行计算。
<img src="figures/vit_framework.png" align="center" style="zoom:80%;" /></p>
<p>其具体流程如下：</p>
<ol class="simple">
<li><p><strong>切分和映射</strong>：对一张标准图像，我们首先将图片切分成一个一个小的 <code class="docutils literal notranslate"><span class="pre">patch</span></code>，然后将它们的维度拉平 <code class="docutils literal notranslate"><span class="pre">Flatten</span></code> 为一维的向量，最后我们将这些向量通过线性映射 <code class="docutils literal notranslate"><span class="pre">Linear</span> <span class="pre">Project</span></code> <span class="math notranslate nohighlight">\(\mathbf{E}\)</span> 到维度为 <span class="math notranslate nohighlight">\(D\)</span> 的空间。</p></li>
<li><p><strong>分类表征和位置信息</strong>：<strong>分类表征</strong>：为了实现图像分类，我们在得到的向量中需要加入一个 <code class="docutils literal notranslate"><span class="pre">class</span> <span class="pre">token</span></code>  <span class="math notranslate nohighlight">\(\mathbf{x}_\text{class}\)</span> 作为分类表征（如上图中标注 <span class="math notranslate nohighlight">\(*\)</span>的粉色向量所示）。<strong>位置信息</strong>：图像和文本一样也需要注意顺序问题，因此作者通过 <code class="docutils literal notranslate"><span class="pre">Position</span> <span class="pre">Embedding</span></code> <span class="math notranslate nohighlight">\(\mathbf{E}_{pos}\)</span> 加入位置编码信息（如上图中标注 <span class="math notranslate nohighlight">\(0-9\)</span> 的紫色向量所示）。</p></li>
<li><p><strong>Transformer Encoder</strong>：然后我们将经过上面操作的 <code class="docutils literal notranslate"><span class="pre">token</span></code> 送入 <code class="docutils literal notranslate"><span class="pre">Transformer</span> <span class="pre">Encoder</span></code>。这里的 <code class="docutils literal notranslate"><span class="pre">Transformer</span> <span class="pre">Encoder</span></code> 和 <code class="docutils literal notranslate"><span class="pre">Transformer</span> <span class="pre">(Attention</span> <span class="pre">is</span> <span class="pre">All</span> <span class="pre">You</span> <span class="pre">Need)</span></code> 文章中实现基本一致，主要是通过多头注意力机制，对 <code class="docutils literal notranslate"><span class="pre">patch</span></code> 之间进行全局的信息提取。</p></li>
<li><p><strong>输出与分类</strong>：对于分类任务，我们只需要获得 <code class="docutils literal notranslate"><span class="pre">class</span> <span class="pre">token</span></code> 经过 <code class="docutils literal notranslate"><span class="pre">Transformer</span> <span class="pre">Encoder</span></code> 得到的输出，加一个 <code class="docutils literal notranslate"><span class="pre">MLP</span> <span class="pre">Head</span></code> 进行分类学习。</p></li>
</ol>
<p>我们论文代码的讲解也将按照上面的流程，对重要模块进行讲解，我们所展示的ViT代码示例来源于<a class="reference external" href="https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py">rwightman/timm</a>并进行了部分简化，在此感谢每一位开源贡献者所作出的贡献。</p>
</section>
<section id="patch-embedding-linear-projection">
<h2>切分和映射 Patch Embedding + Linear Projection<a class="headerlink" href="#patch-embedding-linear-projection" title="永久链接至标题">#</a></h2>
<p>对一张标准图像 <span class="math notranslate nohighlight">\(\mathbf{x}\)</span>，其分辨率为 <span class="math notranslate nohighlight">\(H \times W \times C\)</span>。为了方便讨论，我们取 ViT 的标准输入 <span class="math notranslate nohighlight">\(H \times W \times C = 224 \times 224 \times 3\)</span> 进行一些具体维度的讲解。通过切分操作，我们将整个图片分成多个 <code class="docutils literal notranslate"><span class="pre">patch</span></code> <span class="math notranslate nohighlight">\(\mathbf{x}_p\)</span>，其大小为 <span class="math notranslate nohighlight">\(P \times P \times C = 16 \times 16 \times 3 = 768\)</span>。这样，一共可以得到 <code class="docutils literal notranslate"><span class="pre">patch</span></code> 的数量为 <span class="math notranslate nohighlight">\(N={(H \times W)}/{(P \times P)} = {(224 \times 224)}/{(16 \times 16)} = {(224 / 16)}\times {(224 / 16)} = 14 \times 14 = 196\)</span>。所以，我们将一张 <span class="math notranslate nohighlight">\(224 \times 224 \times 3\)</span> 的标准图片，通过转换得到了 <span class="math notranslate nohighlight">\(196\)</span> 个 <code class="docutils literal notranslate"><span class="pre">patch</span></code>，每个 <code class="docutils literal notranslate"><span class="pre">patch</span></code> 的维度是 <span class="math notranslate nohighlight">\(768\)</span>。</p>
<p>对得到的 <code class="docutils literal notranslate"><span class="pre">patch</span></code> 通过 <span class="math notranslate nohighlight">\(\mathbf{E} \in {\mathbb{R}^{768 \times D}}\)</span> 进行线性映射到维度 <span class="math notranslate nohighlight">\(D\)</span>，我们将映射后的 <code class="docutils literal notranslate"><span class="pre">patch</span></code> 叫做 <code class="docutils literal notranslate"><span class="pre">token</span></code>，以便于和原本 Transformer 的术语进行统一（代码中默认的 <span class="math notranslate nohighlight">\(D\)</span> 仍然为 <span class="math notranslate nohighlight">\(768\)</span>。我们认为，为了不损失信息，这里 <span class="math notranslate nohighlight">\(D\)</span> 满足大于等于 <span class="math notranslate nohighlight">\(768\)</span> 即可）。对应文中公式，上述操作可以表示为：
<span class="math notranslate nohighlight">\(
\begin{align}
[\mathbf{x}_p^1\mathbf{E}; \mathbf{x}_p^2\mathbf{E}; \cdots; \mathbf{x}_p^N\mathbf{E}], \quad \mathbf{E}\in\mathbb{R}^{(P^2\cdot C)\times D}。
\end{align}
\)</span></p>
<p>以上是按照原论文对<strong>切分和映射</strong>的讲解，在实际的代码实现过程中，切分和映射实际上是通过一个二维卷积 <code class="docutils literal notranslate"><span class="pre">nn.Conv2d()</span></code> 一步完成的。为了实现一步操作，作者将卷积核的大小 <code class="docutils literal notranslate"><span class="pre">kernel_size</span></code> 直接设置为了 <code class="docutils literal notranslate"><span class="pre">patch_size</span></code>，即 <span class="math notranslate nohighlight">\(P=16\)</span>。然后，将卷积核的步长 <code class="docutils literal notranslate"><span class="pre">stride</span></code> 也设置为了同样的 <code class="docutils literal notranslate"><span class="pre">patch_size</span></code>，这样就实现了不重复的切割图片。而卷积的特征输入和输出维度，分别设为了 <span class="math notranslate nohighlight">\(C=3\)</span> 和 <span class="math notranslate nohighlight">\(D=768\)</span>，对应下方代码的 <code class="docutils literal notranslate"><span class="pre">in_c</span></code> 和 <code class="docutils literal notranslate"><span class="pre">embed_dim</span></code>。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_c</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">patch_size</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">patch_size</span><span class="p">)</span>
</pre></div>
</div>
<p>一张 <span class="math notranslate nohighlight">\(1 \times 3 \times 224 \times 224\)</span> 的图像（其中 <span class="math notranslate nohighlight">\(1\)</span> 是 <code class="docutils literal notranslate"><span class="pre">batch_size</span></code> 的维度），经过上述卷积操作得到 <span class="math notranslate nohighlight">\(1 \times 768 \times 14 \times 14\)</span> 的张量。（代码中将 <span class="math notranslate nohighlight">\(14 \times 14 = 196\)</span> 当作 <code class="docutils literal notranslate"><span class="pre">grid</span></code> 的个数，即 <code class="docutils literal notranslate"><span class="pre">grid_size=(14,</span> <span class="pre">14)</span></code>）然后，对其进行拉平 <code class="docutils literal notranslate"><span class="pre">flatten(2)</span></code> 得到 <span class="math notranslate nohighlight">\(1 \times 768 \times 196\)</span> 的张量。因为 Transformer 需要将序列维度调整到前面，我们再通过 <code class="docutils literal notranslate"><span class="pre">transpose(1,</span> <span class="pre">2)</span></code> 调整特征和序列维度，最终得到的张量大小为 <span class="math notranslate nohighlight">\(1 \times 196 \times 768\)</span>。切分、映射、拉平和维度调整统统经过下面一步操作得到：</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
</pre></div>
</div>
<p>在代码中，这些操作全部被写在名为 <code class="docutils literal notranslate"><span class="pre">PatchEmbed</span></code> 的模块中，其具体的实现如下所示：</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">PatchEmbed</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Image --&gt; Patch Embedding --&gt; Linear Proj --&gt; Pos Embedding</span>
<span class="sd">    Image size -&gt; [224,224,3]</span>
<span class="sd">    Patch size -&gt; 16*16</span>
<span class="sd">    Patch num -&gt; (224^2)/(16^2)=196</span>
<span class="sd">    Patch dim -&gt; 16*16*3 =768</span>
<span class="sd">    Patch Embedding: [224,224,3] -&gt; [196,768]</span>
<span class="sd">    Linear Proj: [196,768] -&gt; [196,768]</span>
<span class="sd"> 	Positional Embedding: [197,768] -&gt; [197,768]</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img_size</span><span class="o">=</span><span class="mi">224</span><span class="p">,</span> <span class="n">patch_size</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">in_c</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="mi">768</span><span class="p">,</span> <span class="n">norm_layer</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Args:</span>
<span class="sd">            img_size: 默认参数224</span>
<span class="sd">            patch_size: 默认参数是16</span>
<span class="sd">            in_c: 输入的通道数</span>
<span class="sd">            embed_dim: 16*16*3 = 768</span>
<span class="sd">            norm_layer: 是否使用norm层，默认为否</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="n">img_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">img_size</span><span class="p">,</span> <span class="n">img_size</span><span class="p">)</span> <span class="c1"># -&gt; img_size = (224,224)</span>
        <span class="n">patch_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">patch_size</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">)</span> <span class="c1"># -&gt; patch_size = (16,16)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">img_size</span> <span class="o">=</span> <span class="n">img_size</span> <span class="c1"># -&gt; (224,224)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">=</span> <span class="n">patch_size</span> <span class="c1"># -&gt; (16,16)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">grid_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">img_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">img_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="c1"># -&gt; grid_size = (14,14)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_patches</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">grid_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">grid_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># -&gt; num_patches = 196</span>
        <span class="c1"># Patch+linear proj的这个操作 [224,224,3] --&gt; [14,14,768]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_c</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">patch_size</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">patch_size</span><span class="p">)</span>
        <span class="c1"># 判断是否有norm_layer层，要是没有不改变输入</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">)</span> <span class="k">if</span> <span class="n">norm_layer</span> <span class="k">else</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># 计算各个维度的大小</span>
        <span class="n">B</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
        <span class="k">assert</span> <span class="n">H</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">img_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="ow">and</span> <span class="n">W</span> <span class="o">==</span> <span class="bp">self</span><span class="o">.</span><span class="n">img_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> \
            <span class="sa">f</span><span class="s2">&quot;Input image size (</span><span class="si">{</span><span class="n">H</span><span class="si">}</span><span class="s2">*</span><span class="si">{</span><span class="n">W</span><span class="si">}</span><span class="s2">) doesn&#39;t match model (</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">img_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">*</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">img_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2">).&quot;</span>
        
        <span class="c1"># flatten: [B, C, H, W] -&gt; [B, C, HW], flatten(2)代表的是从2位置开始展开</span>
        <span class="c1"># eg: [1,3,224,224] --&gt; [1,768,14,14] -flatten-&gt;[1,768,196]</span>
        <span class="c1"># transpose: [B, C, HW] -&gt; [B, HW, C]</span>
        <span class="c1"># eg: [1,768,196] -transpose-&gt; [1,196,768]</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
</pre></div>
</div>
<p>在默认情况下，这一步是不进行 <code class="docutils literal notranslate"><span class="pre">layer_norm</span></code> 操作的，即它被设置为 <code class="docutils literal notranslate"><span class="pre">nn.Identity()</span></code>。对于 <code class="docutils literal notranslate"><span class="pre">layer_norm</span></code>，我们会在下面进行详细的讲解。</p>
</section>
<section id="class-token-postional-embedding">
<h2>分类表征和位置信息 Class Token + Positional Embedding<a class="headerlink" href="#class-token-postional-embedding" title="永久链接至标题">#</a></h2>
<p>如下图所示，左侧灰色部分为加入分类表征，中间紫色部分为加入位置信息。
<img alt="图源amaarora" src="../_images/vit_cls_pos.png" /></p>
<p><strong>分类表征：Class Token</strong>
为了实现图像分类，我们在切分和映射后的向量 <span class="math notranslate nohighlight">\([\mathbf{x}_p^1\mathbf{E}; \mathbf{x}_p^2\mathbf{E}; \cdots; \mathbf{x}_p^N\mathbf{E}]\)</span> 中加入一个 <code class="docutils literal notranslate"><span class="pre">class</span> <span class="pre">token</span></code>  <span class="math notranslate nohighlight">\(\mathbf{x}_\text{class} \in \mathbb{R}^{D}\)</span> 作为分类表征（如上图中最左侧深灰色框所示）。将这个表征放置在序列的第一个位置上，我们就得到一个维度为 <span class="math notranslate nohighlight">\((196+1) \times 768\)</span> 的新张量：
<span class="math notranslate nohighlight">\(
\begin{align}
[\mathbf{x}_{\text{class}}; \mathbf{x}_p^1\mathbf{E}; \mathbf{x}_p^2\mathbf{E}; \cdots; \mathbf{x}_p^N\mathbf{E}] 
\end{align}
\)</span>
对于具体的代码实现，我们通过 <code class="docutils literal notranslate"><span class="pre">nn.Parameter(torch.zeros(1,</span> <span class="pre">1,</span> <span class="pre">768))</span></code> 实例化一个可学习的 <code class="docutils literal notranslate"><span class="pre">cls_token</span></code>，然后将这个 <code class="docutils literal notranslate"><span class="pre">cls_token</span></code> 按照 <code class="docutils literal notranslate"><span class="pre">batch_size</span> <span class="pre">=</span> <span class="pre">x.shape[0]</span></code> 进行复制，最后将其和之前经过切分和映射的 <code class="docutils literal notranslate"><span class="pre">x</span></code> 并在一起 <code class="docutils literal notranslate"> <span class="pre">torch.cat((cls_token,</span> <span class="pre">x),</span> <span class="pre">dim=1)</span></code>。其完整代码，如下所示：</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">cls_token</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">768</span><span class="p">))</span> <span class="c1"># -&gt; cls token</span>
<span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">trunc_normal_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls_token</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">0.02</span><span class="p">)</span> <span class="c1"># 初始化</span>
<span class="n">cls_token</span> <span class="o">=</span> <span class="n">cls_token</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># (1,1,768) -&gt; (128,1,768)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">cls_token</span><span class="p">,</span> <span class="n">x</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># [128, 197, 768]</span>
</pre></div>
</div>
<p>其实也可以不加入这个 <code class="docutils literal notranslate"><span class="pre">cls</span> <span class="pre">token</span></code>，我们可以对输出 <code class="docutils literal notranslate"><span class="pre">token</span></code> 做 <code class="docutils literal notranslate"><span class="pre">GAP(Global</span> <span class="pre">Average</span> <span class="pre">Pooling)</span></code>，然后对 <code class="docutils literal notranslate"><span class="pre">GAP</span></code> 的结果进行分类。</p>
<p><strong>位置信息：Positional Embedding</strong>
图像和文本一样也需要注意顺序问题，因此作者通过 <code class="docutils literal notranslate"><span class="pre">Position</span> <span class="pre">Embedding</span></code> <span class="math notranslate nohighlight">\(\mathbf{E}_{\text{pos}}\in\mathbb{R}^{(N + 1)\times D}\)</span> 加入位置编码信息。这个 <code class="docutils literal notranslate"><span class="pre">Position</span> <span class="pre">Embedding</span></code> 和上面得到的分类表征张量，直接相加：
<span class="math notranslate nohighlight">\(
\begin{align}
\mathbf{z}_0 &amp;= [\mathbf{x}_{\text{class}}; \mathbf{x}_p^1\mathbf{E}; \mathbf{x}_p^2\mathbf{E}; \cdots; \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{\text{pos}}, &amp; \mathbf{E}&amp;\in\mathbb{R}^{(P^2\cdot C)\times D}, \mathbf{E}_{\text{pos}}\in\mathbb{R}^{(N + 1)\times D}
\end{align}
\)</span></p>
<p>与 Transformer 使用余弦位置编码不同的是，ViT 通过<code class="docutils literal notranslate"><span class="pre">nn.Parameter()</span></code>实现了一个可以学习的位置编码。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">num_patches</span> <span class="o">=</span> <span class="mi">196</span>
<span class="n">pos_embed</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_patches</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">768</span><span class="p">))</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">pos_embed</span>
</pre></div>
</div>
<p>这里 <code class="docutils literal notranslate"><span class="pre">pos_embed</span></code> 在 <code class="docutils literal notranslate"><span class="pre">batch_size</span></code> 的维度进行了 broadcast（广播），所以所有的样本都是同样的 <code class="docutils literal notranslate"><span class="pre">pos_embed</span></code>。</p>
</section>
<section id="transformer-encoder">
<h2>Transformer Encoder<a class="headerlink" href="#transformer-encoder" title="永久链接至标题">#</a></h2>
<p>下一步，我们只需要将序列 <span class="math notranslate nohighlight">\(\mathbf{z}_0\)</span> 输入 Transformer Encoder 即可。如下图所示，每个 Transformer Encoder 由 Multi-head Attention、MLP、Norm (Layer Norm,LN) 并外加 shortcut 连接实现。
<span class="math notranslate nohighlight">\(
\begin{align}
\mathbf{z}'_l &amp;= \text{MSA}(\text{LN}(\mathbf{z}_{l-1})) + \mathbf{z}_{l-1}, &amp; l &amp;=1\dots L, \\
\mathbf{z}_l &amp;= \text{MLP}(\text{LN}(\mathbf{z}'_l)) + \mathbf{z}'_l,  &amp; l &amp;=1\dots L, \\
\mathbf{y} &amp;= \text{LN}(\mathbf{z}_L^0)
\end{align}
\)</span></p>
<center>
<figure>
<img src="figures/vit_transformer.png" style="zoom:37%;" alt="Transformer Encoder 结构图">
</figure>
</center>
<p>下面我们将对这些模块逐一进行讲解。</p>
<section id="multi-head-attention">
<h3>Multi-head Attention<a class="headerlink" href="#multi-head-attention" title="永久链接至标题">#</a></h3>
<p>Multi-head Attention 或者叫做 Multi-head Self-Attention (MSA) 是由多个 Self-attention (SA) 模块组成，它们的框图可由下面所示，其中左侧为 SA，右侧为 MSA。</p>
<center>
<figure>
<img src="figures/vit_sa.png" style="zoom:37%;" alt="Self-Attention (SA) 模块框图">
·
·
·
·
·
·
<img src="figures/vit_msa.png" style="zoom:30%;" alt="Multi-head Self-Attention (MSA) 模块框图">
</figure>
</center>
<p>对于一个标准的 SA 模块，我们通过对输入张量 <span class="math notranslate nohighlight">\(\mathbf{z}\)</span> 进行一个映射 <span class="math notranslate nohighlight">\(\mathbf{W_{SA}}\)</span> 得到 <span class="math notranslate nohighlight">\(Q, K, V\)</span>
<span class="math notranslate nohighlight">\(
[Q, K, V] = \mathbf{z} \mathbf{W}_{\text{SA}}.
\)</span>
对于 MSA，我们需要对其输入再次进行切分为 <span class="math notranslate nohighlight">\(k\)</span> 个部分（<span class="math notranslate nohighlight">\(k=\)</span><code class="docutils literal notranslate"><span class="pre">self.num_heads</span></code>），而每个部分的维度为原本维度的 <span class="math notranslate nohighlight">\(k\)</span> 分之一，即 <code class="docutils literal notranslate"><span class="pre">C</span> <span class="pre">//</span> <span class="pre">self.num_heads</span></code>。然后，将维度进行调整，即 <code class="docutils literal notranslate"><span class="pre">q,</span> <span class="pre">k,</span> <span class="pre">v</span></code> 到第 1 个维度， 批大小 <code class="docutils literal notranslate"><span class="pre">batch_size</span></code> 为第 2 个维度，头的数量 <code class="docutils literal notranslate"><span class="pre">num_heads</span></code> 为第 3 个维度，切分块的数量 <code class="docutils literal notranslate"><span class="pre">num_patches</span></code> 和每个头的特征维度 <code class="docutils literal notranslate"><span class="pre">embed_dim_per_head</span></code> 为最后两个维度。这种维度调整，将方便提取 <code class="docutils literal notranslate"><span class="pre">q,</span> <span class="pre">k,</span> <span class="pre">v</span></code>，以及后面的注意力计算。上述步骤在代码中对应：</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="bp">self</span><span class="o">.</span><span class="n">qkv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dim</span> <span class="o">*</span> <span class="mi">3</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">qkv_bias</span><span class="p">)</span>
<span class="n">qkv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">C</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">)</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
<span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>  <span class="c1"># seperate q, k, v</span>
</pre></div>
</div>
<p>现在，如果我们将每一个 <code class="docutils literal notranslate"><span class="pre">head</span></code>，看作一个独立的计算单元。我们可以对每一个<code class="docutils literal notranslate"><span class="pre">head</span></code> 进行标准的 SA 计算
<span class="math notranslate nohighlight">\(
Attention(Q, K, V) = softmax(\frac{Q K^T}{\sqrt {D_k}}) \cdot V
\)</span>
然后，这些 <code class="docutils literal notranslate"><span class="pre">head</span></code> 会被拼接在一起，计算最终的输出：
<span class="math notranslate nohighlight">\(
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head_1}, ...,
\mathrm{head_h})W^O    \\
    \text{where}~\mathrm{head_i} = \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i)
\)</span></p>
<p>其中 <span class="math notranslate nohighlight">\(W^O\)</span> 代表的是线性变换层，<span class="math notranslate nohighlight">\(head_i\)</span> 代表的是每个 <code class="docutils literal notranslate"><span class="pre">head</span></code> 的输出，其中 <span class="math notranslate nohighlight">\(W^Q_i\)</span>，<span class="math notranslate nohighlight">\(W^K_i\)</span>, <span class="math notranslate nohighlight">\(W^V_i\)</span>，等价于每个 <code class="docutils literal notranslate"><span class="pre">head</span></code> 的线性映射权重（如上面计算 <code class="docutils literal notranslate"><span class="pre">qkv</span></code>所讲，实际代码实现中，我们会先一起计算 <code class="docutils literal notranslate"><span class="pre">qkv</span></code>，再进行 <code class="docutils literal notranslate"><span class="pre">head</span></code> 的切分）。如果按照默认实现，一般切分为 <span class="math notranslate nohighlight">\(k=8\)</span> 个头，其中 <span class="math notranslate nohighlight">\(D_k=D/k = 768/8=96\)</span>，是为了归一化点乘的结果。</p>
<p>在代码实现的时候，作者充分考虑了多头的并行计算。通过点乘的形式对所有的 <code class="docutils literal notranslate"><span class="pre">head</span></code> 一起计算相关性 <code class="docutils literal notranslate"><span class="pre">(q</span> <span class="pre">&#64;</span> <span class="pre">k.transpose(-2,</span> <span class="pre">-1))</span></code>，然后经过 <code class="docutils literal notranslate"><span class="pre">softmax</span></code> 得到权重 <code class="docutils literal notranslate"><span class="pre">attn</span></code> （这些权重的维度为 <code class="docutils literal notranslate"><span class="pre">[batch_size,</span> <span class="pre">num_heads,</span> <span class="pre">num_patches</span> <span class="pre">+</span> <span class="pre">1,</span> <span class="pre">num_patches</span> <span class="pre">+</span> <span class="pre">1]</span></code>）。
之后将这些权重 <code class="docutils literal notranslate"><span class="pre">attn</span></code> 和 <code class="docutils literal notranslate"><span class="pre">v</span></code> （其维度为 <code class="docutils literal notranslate"><span class="pre">[batch_size,</span> <span class="pre">num_heads,</span> <span class="pre">num_patches+1,</span> <span class="pre">embed_dim_per_head]</span></code>） 进行点乘，得到注意力的输出结果。这里在点乘的时候，我们只需要看 <code class="docutils literal notranslate"><span class="pre">attn</span></code> 和 <code class="docutils literal notranslate"><span class="pre">v</span></code>的最后两个维度，分别为<code class="docutils literal notranslate"><span class="pre">[num_patches</span> <span class="pre">+</span> <span class="pre">1,</span> <span class="pre">num_patches</span> <span class="pre">+</span> <span class="pre">1]</span></code> 和 <code class="docutils literal notranslate"><span class="pre">[num_patches+1,</span> <span class="pre">embed_dim_per_head]</span></code>，维持其他维度不变，我们可以得到输出的结果维度为 <code class="docutils literal notranslate"><span class="pre">[batch_size,</span> <span class="pre">num_heads,</span> <span class="pre">num_patches</span> <span class="pre">+</span> <span class="pre">1,</span> <span class="pre">embed_dim_per_head]</span></code>。
最后，我们通过将特征维度和多头维度交换 <code class="docutils literal notranslate"><span class="pre">transpose(1,</span> <span class="pre">2)</span></code> 和 重组第2个及后面所有的维度 <code class="docutils literal notranslate"><span class="pre">reshape(B,</span> <span class="pre">N,</span> <span class="pre">C)</span></code>，就可以得到维度为 <code class="docutils literal notranslate"><span class="pre">[batch_size,</span> <span class="pre">num_patches</span> <span class="pre">+</span> <span class="pre">1,</span> <span class="pre">total_embed_dim]</span></code> 和上面公式相同的并行多头计算结果。其完整实现如下所示</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Attention</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                 <span class="n">dim</span><span class="p">,</span>   <span class="c1"># 输入token的dim</span>
                 <span class="n">num_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="c1"># attention head的个数</span>
                 <span class="n">qkv_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="c1"># 是否使用qkv bias</span>
                 <span class="n">qk_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">attn_drop_ratio</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span>
                 <span class="n">proj_drop_ratio</span><span class="o">=</span><span class="mf">0.</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">Attention</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span>
        <span class="c1"># 计算每一个head处理的维度head_dim = dim // num_heads --&gt; 768/8 = 96</span>
        <span class="n">head_dim</span> <span class="o">=</span> <span class="n">dim</span> <span class="o">//</span> <span class="n">num_heads</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">qk_scale</span> <span class="ow">or</span> <span class="n">head_dim</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span> <span class="c1"># 根下dk操作</span>
        <span class="c1"># 使用nn.Linear生成w_q,w_k,w_v，因为本质上每一个变换矩阵都是线性变换，</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">qkv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dim</span> <span class="o">*</span> <span class="mi">3</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">qkv_bias</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">attn_drop</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">attn_drop_ratio</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">proj_drop</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">proj_drop_ratio</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># [batch_size, num_patches + 1, total_embed_dim]</span>
        <span class="c1"># total_embed_dim不是一开始展开的那个维度，是经过了一个线性变换层得到的</span>
        <span class="n">B</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>

        <span class="c1"># [batch_size, num_patches+1, total_embed_dim] -qkv()-&gt; [batch_size, num_patches + 1, 3 * total_embed_dim]</span>
        <span class="c1"># reshape: -&gt; [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]</span>
        <span class="c1"># permute: -&gt; [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]</span>
        <span class="n">qkv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">C</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">)</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
        <span class="c1"># q,k,v = [batch_size, num_heads, num_patches + 1, embed_dim_per_head]</span>
        <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>  <span class="c1"># make torchscript happy (cannot use tensor as tuple)</span>

        <span class="c1"># transpose(-2,-1)在最后两个维度进行操作，输入的形状[batch_size,num_heads,num_patches+1,embed_dim_per_head]</span>
        <span class="c1"># transpose: -&gt; [batch_size, num_heads, embed_dim_per_head, num_patches + 1]</span>
        <span class="c1"># @: multiply -&gt; [batch_size, num_heads, num_patches + 1, num_patches + 1]</span>
        <span class="n">attn</span> <span class="o">=</span> <span class="p">(</span><span class="n">q</span> <span class="o">@</span> <span class="n">k</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span>
        <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_drop</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span>

        <span class="c1"># @: multiply -&gt; [batch_size, num_heads, num_patches + 1, embed_dim_per_head]</span>
        <span class="c1"># transpose: -&gt; [batch_size, num_patches + 1, num_heads, embed_dim_per_head]</span>
        <span class="c1"># reshape: -&gt; [batch_size, num_patches + 1, total_embed_dim]</span>
        <span class="n">x</span> <span class="o">=</span> <span class="p">(</span><span class="n">attn</span> <span class="o">@</span> <span class="n">v</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">C</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj_drop</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
</pre></div>
</div>
</section>
<section id="mlp">
<h3>MLP<a class="headerlink" href="#mlp" title="永久链接至标题">#</a></h3>
<p>MLP层类似于原始Transformer中的Feed Forward Network。</p>
<blockquote>
<div><p>In ViT, only MLP layers are local and translationally equivariant, while the self-attention layers are global.</p>
</div></blockquote>
<p>为了理解这句话，即 MLP 只对局部信息进行操作，我们需要强调 <code class="docutils literal notranslate"><span class="pre">nn.Linear()</span></code> 操作只对输入张量的最后一个维度进行操作。那么，对于输入维度为 <code class="docutils literal notranslate"><span class="pre">[batch_size,</span> <span class="pre">num_patches</span> <span class="pre">+</span> <span class="pre">1,</span> <span class="pre">total_embed_dim]</span></code>，学习到的线性层对于所有 <code class="docutils literal notranslate"><span class="pre">patch</span></code> 都是一样的。所以，它是一个局部信息的建模。对于 Attention，因为它是在不同的 <code class="docutils literal notranslate"><span class="pre">patch</span></code> 层面或者不同的序列层面进行建模，所以是全局信息建模。因此，作者使用了 MLP 和 Attention 一起进行局部和全局信息的提取。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Mlp</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    in_features --&gt; hidden_features --&gt; out_features</span>
<span class="sd">    论文实现时：in_features.shape = out_features.shape</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_features</span><span class="p">,</span> <span class="n">hidden_features</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">act_layer</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">,</span> <span class="n">drop</span><span class="o">=</span><span class="mf">0.</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="c1"># 用or实现了或操作，当hidden_features/out_features为默认值None时</span>
        <span class="c1"># 此时out_features/hidden_features=None or in_features = in_features</span>
        <span class="c1"># 当对out_features或hidden_features进行输入时，or操作将会默认选择or前面的</span>
        <span class="c1"># 此时out_features/hidden_features = out_features/hidden_features</span>
        <span class="n">out_features</span> <span class="o">=</span> <span class="n">out_features</span> <span class="ow">or</span> <span class="n">in_features</span>
        <span class="n">hidden_features</span> <span class="o">=</span> <span class="n">hidden_features</span> <span class="ow">or</span> <span class="n">in_features</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="p">,</span> <span class="n">hidden_features</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">act_layer</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_features</span><span class="p">,</span> <span class="n">out_features</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">drop</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">drop</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># in_features --&gt; hidden_features --&gt; out_features</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
</pre></div>
</div>
</section>
<section id="layer-norm">
<h3>Layer Norm<a class="headerlink" href="#layer-norm" title="永久链接至标题">#</a></h3>
<p>Normalization 有很多种，但是它们都有一个共同的目的，那就是把输入转化成均值为 0 方差为 1 的数据（或者某个学习到的均值和方差）。我们在把数据送入激活函数之前进行 Normalization（归一化），因为我们不希望输入数据落在激活函数的饱和区。</p>
<p>Batch Norm 的作用是在对这批样本的同一维度特征做归一化，而 Layer Norm 的作用是对<strong>单个样本的所有维度特征做归一化</strong>。举一个简单的例子，对于通过编码的句子“我爱学习”，Batch Norm 是对这四个字进行归一化，而 Layer Norm 是对每个字本身的特征进行归一化。</p>
<p>对于 Layer Norm，其公式如下所示
<span class="math notranslate nohighlight">\(L N\left(x_i\right)=\alpha \times \frac{x_i-\mu_L}{\sqrt{\sigma_L^2+\epsilon}}+\beta\)</span>
可以通过 <code class="docutils literal notranslate"><span class="pre">nn.LayerNorm</span></code> 进行实现。</p>
</section>
<section id="id3">
<h3>Transformer Encoder 完整代码<a class="headerlink" href="#id3" title="永久链接至标题">#</a></h3>
<p>整合上面 Multi-head Attention、MLP、Norm (Layer Norm, LN) 并外加 shortcut 连接代码，我们可以得到 Transformer Encoder 的完整代码。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Block</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    每一个Encoder Block的构成</span>
<span class="sd">    每个Encode Block的流程：norm1 --&gt; Multi-Head Attention --&gt; norm2 --&gt; MLP</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                 <span class="n">dim</span><span class="p">,</span> <span class="c1"># 输入mlp的维度</span>
                 <span class="n">num_heads</span><span class="p">,</span> <span class="c1"># Multi-Head-Attention的头个数</span>
                 <span class="n">mlp_ratio</span><span class="o">=</span><span class="mf">4.</span><span class="p">,</span> <span class="c1"># hidden_features / in_features = mlp_ratio</span>
                 <span class="n">qkv_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="c1"># q,k,v的生成是否使用bias</span>
                 <span class="n">qk_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">drop_ratio</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="c1"># dropout的比例</span>
                 <span class="n">attn_drop_ratio</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="c1"># 注意力dropout的比例</span>
                 <span class="n">drop_path_ratio</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span>
                 <span class="n">act_layer</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">GELU</span><span class="p">,</span> <span class="c1"># 激活函数默认使用GELU</span>
                 <span class="n">norm_layer</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">):</span> <span class="c1"># Norm默认使用LayerNorm</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">Block</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="c1"># 第一层normalization</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
        <span class="c1"># self.attention层的实现</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">Attention</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="o">=</span><span class="n">qkv_bias</span><span class="p">,</span> <span class="n">qk_scale</span><span class="o">=</span><span class="n">qk_scale</span><span class="p">,</span><span class="n">attn_drop_ratio</span><span class="o">=</span><span class="n">attn_drop_ratio</span><span class="p">,</span> <span class="n">proj_drop_ratio</span><span class="o">=</span><span class="n">drop_ratio</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">drop_path</span> <span class="o">=</span> <span class="n">DropPath</span><span class="p">(</span><span class="n">drop_path_ratio</span><span class="p">)</span> <span class="k">if</span> <span class="n">drop_path_ratio</span> <span class="o">&gt;</span> <span class="mf">0.</span> <span class="k">else</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span>
        <span class="c1"># 第二层normalization</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
        <span class="n">mlp_hidden_dim</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">dim</span> <span class="o">*</span> <span class="n">mlp_ratio</span><span class="p">)</span> <span class="c1"># hidden_dim = dim * mlp_ratio</span>
        <span class="c1"># mlp实现</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">mlp</span> <span class="o">=</span> <span class="n">Mlp</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="n">dim</span><span class="p">,</span> <span class="n">hidden_features</span><span class="o">=</span><span class="n">mlp_hidden_dim</span><span class="p">,</span> <span class="n">act_layer</span><span class="o">=</span><span class="n">act_layer</span><span class="p">,</span> <span class="n">drop</span><span class="o">=</span><span class="n">drop_ratio</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># 实现了两个残差连接</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_path</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_path</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mlp</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span>
        <span class="k">return</span> <span class="n">x</span>
</pre></div>
</div>
</section>
</section>
<section id="id4">
<h2>ViT 完整代码<a class="headerlink" href="#id4" title="永久链接至标题">#</a></h2>
<p>对输入图像，进行切分和映射、加入分类表征和位置信息、经过 Transformer Encoder、然后添加一个分类头进行输出，我们就完成了 ViT 所有的代码。</p>
<p>完整的 ViT 主要模块流程，见下方 <code class="docutils literal notranslate"><span class="pre">VisionTransformer</span></code>。</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">VisionTransformer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                 <span class="n">img_size</span><span class="o">=</span><span class="mi">224</span><span class="p">,</span>
                 <span class="n">patch_size</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span>
                 <span class="n">in_c</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
                 <span class="n">num_classes</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span>
                 <span class="n">embed_dim</span><span class="o">=</span><span class="mi">768</span><span class="p">,</span>
                 <span class="n">depth</span><span class="o">=</span><span class="mi">12</span><span class="p">,</span>
                 <span class="n">num_heads</span><span class="o">=</span><span class="mi">12</span><span class="p">,</span>
                 <span class="n">mlp_ratio</span><span class="o">=</span><span class="mf">4.0</span><span class="p">,</span>
                 <span class="n">qkv_bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
                 <span class="n">qk_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">representation_size</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">distilled</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                 <span class="n">drop_ratio</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span>
                 <span class="n">attn_drop_ratio</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span>
                 <span class="n">drop_path_ratio</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span>
                 <span class="n">embed_layer</span><span class="o">=</span><span class="n">PatchEmbed</span><span class="p">,</span>
                 <span class="n">norm_layer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
                 <span class="n">act_layer</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        Args:</span>
<span class="sd">            img_size (int, tuple): input image size</span>
<span class="sd">            patch_size (int, tuple): patch size</span>
<span class="sd">            in_c (int): number of input channels</span>
<span class="sd">            num_classes (int): number of classes for classification head</span>
<span class="sd">            embed_dim (int): embedding dimension</span>
<span class="sd">            depth (int): depth of transformer</span>
<span class="sd">            num_heads (int): number of attention heads</span>
<span class="sd">            mlp_ratio (int): ratio of mlp hidden dim to embedding dim</span>
<span class="sd">            qkv_bias (bool): enable bias for qkv if True</span>
<span class="sd">            qk_scale (float): override default qk scale of head_dim ** -0.5 if set</span>
<span class="sd">            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set</span>
<span class="sd">            distilled (bool): model includes a distillation token and head as in DeiT models</span>
<span class="sd">            drop_ratio (float): dropout rate</span>
<span class="sd">            attn_drop_ratio (float): attention dropout rate</span>
<span class="sd">            drop_path_ratio (float): stochastic depth rate</span>
<span class="sd">            embed_layer (nn.Module): patch embedding layer</span>
<span class="sd">            norm_layer: (nn.Module): normalization layer</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">VisionTransformer</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span> <span class="o">=</span> <span class="n">num_classes</span>
        <span class="c1"># 每个patch的图像维度 = embed_dim</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span>  <span class="c1"># num_features for consistency with other models</span>
        <span class="c1"># token的个数为1</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_tokens</span> <span class="o">=</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">distilled</span> <span class="k">else</span> <span class="mi">1</span>
        <span class="c1"># 设置激活函数和norm函数</span>
        <span class="n">norm_layer</span> <span class="o">=</span> <span class="n">norm_layer</span> <span class="ow">or</span> <span class="n">partial</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)</span>
        <span class="n">act_layer</span> <span class="o">=</span> <span class="n">act_layer</span> <span class="ow">or</span> <span class="n">nn</span><span class="o">.</span><span class="n">GELU</span>
        <span class="c1"># 对应的将图片打成patch的操作</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">patch_embed</span> <span class="o">=</span> <span class="n">embed_layer</span><span class="p">(</span><span class="n">img_size</span><span class="o">=</span><span class="n">img_size</span><span class="p">,</span> <span class="n">patch_size</span><span class="o">=</span><span class="n">patch_size</span><span class="p">,</span> <span class="n">in_c</span><span class="o">=</span><span class="n">in_c</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span>
        <span class="n">num_patches</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_embed</span><span class="o">.</span><span class="n">num_patches</span>
        <span class="c1"># 设置分类的cls_token</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">cls_token</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">))</span>
        <span class="c1"># distilled 是Deit中的 这里为None</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dist_token</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">))</span> <span class="k">if</span> <span class="n">distilled</span> <span class="k">else</span> <span class="kc">None</span>
        <span class="c1"># pos_embedding 为一个可以学习的参数</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">pos_embed</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_patches</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_tokens</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">))</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">pos_drop</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">drop_ratio</span><span class="p">)</span>

        <span class="n">dpr</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">drop_path_ratio</span><span class="p">,</span> <span class="n">depth</span><span class="p">)]</span>  <span class="c1"># stochastic depth decay rule</span>
        <span class="c1"># 使用nn.Sequential进行构建，ViT中深度为12</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">blocks</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="p">[</span>
            <span class="n">Block</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">mlp_ratio</span><span class="o">=</span><span class="n">mlp_ratio</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="o">=</span><span class="n">qkv_bias</span><span class="p">,</span> <span class="n">qk_scale</span><span class="o">=</span><span class="n">qk_scale</span><span class="p">,</span>
                  <span class="n">drop_ratio</span><span class="o">=</span><span class="n">drop_ratio</span><span class="p">,</span> <span class="n">attn_drop_ratio</span><span class="o">=</span><span class="n">attn_drop_ratio</span><span class="p">,</span> <span class="n">drop_path_ratio</span><span class="o">=</span><span class="n">dpr</span><span class="p">[</span><span class="n">i</span><span class="p">],</span>
                  <span class="n">norm_layer</span><span class="o">=</span><span class="n">norm_layer</span><span class="p">,</span> <span class="n">act_layer</span><span class="o">=</span><span class="n">act_layer</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">depth</span><span class="p">)</span>
        <span class="p">])</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">)</span>

        <span class="c1"># Representation layer</span>
        <span class="k">if</span> <span class="n">representation_size</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">distilled</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">has_logits</span> <span class="o">=</span> <span class="kc">True</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">num_features</span> <span class="o">=</span> <span class="n">representation_size</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">pre_logits</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="n">OrderedDict</span><span class="p">([</span>
                <span class="p">(</span><span class="s2">&quot;fc&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">representation_size</span><span class="p">)),</span>
                <span class="p">(</span><span class="s2">&quot;act&quot;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Tanh</span><span class="p">())</span>
            <span class="p">]))</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">has_logits</span> <span class="o">=</span> <span class="kc">False</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">pre_logits</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span>

        <span class="c1"># Classifier head(s)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">head</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_features</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span> <span class="k">if</span> <span class="n">num_classes</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">head_dist</span> <span class="o">=</span> <span class="kc">None</span>
        <span class="k">if</span> <span class="n">distilled</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">head_dist</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span><span class="p">)</span> <span class="k">if</span> <span class="n">num_classes</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span>

        <span class="c1"># Weight init</span>
        <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">trunc_normal_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pos_embed</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">0.02</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dist_token</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">trunc_normal_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dist_token</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">0.02</span><span class="p">)</span>

        <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">trunc_normal_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls_token</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">0.02</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">_init_vit_weights</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward_features</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># [B, C, H, W] -&gt; [B, num_patches, embed_dim]</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_embed</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>  <span class="c1"># [B, 196, 768]</span>
        <span class="c1"># [1, 1, 768] -&gt; [B, 1, 768]</span>
        <span class="n">cls_token</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">cls_token</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dist_token</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">cls_token</span><span class="p">,</span> <span class="n">x</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># [B, 197, 768]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">cls_token</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dist_token</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="p">),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_drop</span><span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_embed</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">blocks</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">dist_token</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
            <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">pre_logits</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">]</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">forward_features</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_dist</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">x</span><span class="p">,</span> <span class="n">x_dist</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">head</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_dist</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
            <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">jit</span><span class="o">.</span><span class="n">is_scripting</span><span class="p">():</span>
                <span class="c1"># during training, return both classifier predictions (for the distillation loss)</span>
                <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">x_dist</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="c1"># during inference, return the average of both classifier predictions</span>
                <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">x_dist</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">head</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>


<span class="k">def</span> <span class="nf">_init_vit_weights</span><span class="p">(</span><span class="n">m</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    ViT weight initialization</span>
<span class="sd">    :param m: module</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">):</span>
        <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">trunc_normal_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">.01</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">m</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">zeros_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span>
    <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">):</span>
        <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">kaiming_normal_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">&quot;fan_out&quot;</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">m</span><span class="o">.</span><span class="n">bias</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">zeros_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span>
    <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">):</span>
        <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">zeros_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span>
        <span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="o">.</span><span class="n">ones_</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">weight</span><span class="p">)</span>
</pre></div>
</div>
<blockquote>
<div><p>参考：</p>
<p>An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <a href="https://arxiv.org/pdf/2010.11929.pdf">https://arxiv.org/pdf/2010.11929.pdf</a><br>
Attention Is All You Need <a href="https://arxiv.org/abs/1706.03762">https://arxiv.org/abs/1706.03762</a></p>
</div></blockquote>
</section>
</section>


              </div>
              
            </main>
            <footer class="footer-article noprint">
                
    <!-- Previous / next buttons -->
<div class='prev-next-area'>
    <a class='left-prev' id="prev-link" href="Transformer%20%E8%A7%A3%E8%AF%BB.html" title="上一页">
        <i class="fas fa-angle-left"></i>
        <div class="prev-next-info">
            <p class="prev-next-subtitle">上一页</p>
            <p class="prev-next-title">Transformer 解读</p>
        </div>
    </a>
    <a class='right-next' id="next-link" href="Swin-Transformer%E8%A7%A3%E8%AF%BB.html" title="下一页">
    <div class="prev-next-info">
        <p class="prev-next-subtitle">下一页</p>
        <p class="prev-next-title">Swin Transformer解读</p>
    </div>
    <i class="fas fa-angle-right"></i>
    </a>
</div>
            </footer>
        </div>
    </div>
    <div class="footer-content row">
        <footer class="col footer"><p>
  
    By ZhikangNiu<br/>
  
      &copy; Copyright 2022, ZhikangNiu.<br/>
</p>
        </footer>
    </div>
    
</div>


      </div>
    </div>
  
  <!-- Scripts loaded after <body> so the DOM is not blocked -->
  <script src="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"></script>


  </body>
</html>